package data_deepprocessing.algorithm.crfs;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.regex.Pattern;

import cc.mallet.fst.CRF;
import cc.mallet.fst.CRFOptimizableByLabelLikelihood;
import cc.mallet.fst.CRFTrainerByValueGradients;
import cc.mallet.fst.MaxLatticeDefault;
import cc.mallet.fst.MultiSegmentationEvaluator;
import cc.mallet.fst.NoopTransducerTrainer;
import cc.mallet.fst.Transducer;
import cc.mallet.fst.TransducerEvaluator;
import cc.mallet.fst.TransducerTrainer;
import cc.mallet.optimize.Optimizable;
import cc.mallet.pipe.Pipe;
import cc.mallet.pipe.SerialPipes;
import cc.mallet.pipe.SimpleTaggerSentence2TokenSequence;
import cc.mallet.pipe.TokenSequence2FeatureVectorSequence;
import cc.mallet.pipe.iterator.LineGroupIterator;
import cc.mallet.pipe.tsf.FeaturesInWindow;
import cc.mallet.pipe.tsf.LexiconMembership;
import cc.mallet.pipe.tsf.OffsetConjunctions;
import cc.mallet.pipe.tsf.RegexMatches;
import cc.mallet.pipe.tsf.TokenTextCharNGrams;
import cc.mallet.pipe.tsf.TokenTextCharPrefix;
import cc.mallet.pipe.tsf.TokenTextCharSuffix;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Sequence;

/**
 * Trains, serializes and evaluates a linear-chain CRF (MALLET) for named-entity
 * tagging of Chinese clinical text. Labels follow a B-/E- scheme for symptoms (S),
 * diseases (D) and inducing factors (I).
 *
 * <p>Usage: call one of the {@link #loadData} overloads first, then
 * {@link #trainModel} and/or {@link #testModel}.
 */
public class CrfApp {

	/** A line containing only whitespace separates instances (sentences) in the input files. */
	private static final Pattern BLANK_LINE = Pattern.compile("^\\s*$");

	/** Training instances; populated by {@link #loadData}. */
	private InstanceList trainInstances;

	/** Test instances; populated by {@link #loadData}. */
	private InstanceList testInstances;

	/** Maps test-file index to its File, so {@link #testModel} can name per-file output. */
	private Map<Integer, File> testmap = new HashMap<Integer, File>();

	/**
	 * Loads training and test data from two directories and builds the feature template.
	 *
	 * @param trainDir    directory of training files (one tagged token per line,
	 *                    sentences separated by blank lines)
	 * @param testDir     directory of test files in the same format
	 * @param dicPath     directory containing the lexicons seeddic.txt, B-S.txt, E-S.txt
	 * @param patternList optional extra indicator-word strings; each entry p is turned
	 *                    into the feature regex "[p].*" (may be null or empty)
	 * @throws Exception if a directory cannot be listed or a file cannot be read
	 */
	public void loadData(String trainDir, String testDir, String dicPath, List<String> patternList)
			throws Exception {

		// Fail fast on a bad directory instead of NPE-ing on a null listFiles() result.
		File[] trainfiles = listFilesOrFail(trainDir);
		File[] testfiles = listFilesOrFail(testDir);

		// Context window (-2, 2); this overload also honours the caller-supplied patterns.
		Pipe pipe = buildPipe(dicPath, patternList, -2, 2);
		trainInstances = new InstanceList(pipe);
		testInstances = new InstanceList(pipe);

		for (int i = 0; i < trainfiles.length; i++) {
			addFileThruPipe(trainInstances, trainfiles[i]);
		}
		for (int i = 0; i < testfiles.length; i++) {
			testmap.put(i, testfiles[i]);
			addFileThruPipe(testInstances, testfiles[i]);
		}
		System.out.println("testInstances: " + testInstances.size() + " " + testfiles.length);
		System.out.println("trainInstances: " + trainInstances.size() + " " + trainfiles.length);
	}

	/**
	 * Loads explicit lists of training and test files; used for ten-fold cross validation.
	 *
	 * @param trainfiles  training files (same line format as the directory-based overload)
	 * @param testfiles   test files
	 * @param dicPath     lexicon directory (seeddic.txt, B-S.txt, E-S.txt)
	 * @param patternList accepted for signature compatibility but intentionally NOT used
	 *                    here (the original had this feature disabled for this overload)
	 * @throws FileNotFoundException if any listed file or lexicon is missing
	 */
	public void loadData(List<File> trainfiles, List<File> testfiles, String dicPath,
			List<String> patternList) throws FileNotFoundException {
		// Wider context window (-3, 3) than the directory-based overload; no extra patterns.
		Pipe pipe = buildPipe(dicPath, null, -3, 3);
		trainInstances = new InstanceList(pipe);
		testInstances = new InstanceList(pipe);

		for (int i = 0; i < trainfiles.size(); i++) {
			addFileThruPipe(trainInstances, trainfiles.get(i));
		}
		for (int i = 0; i < testfiles.size(); i++) {
			testmap.put(i, testfiles.get(i));
			addFileThruPipe(testInstances, testfiles.get(i));
		}
		System.out.println("testInstances: " + testInstances.size() + " " + testfiles.size());
	}

	/**
	 * Builds the serial feature-extraction pipe shared by both loadData overloads:
	 * token-text features, neighbour conjunctions, a context window, indicator-word
	 * regexes, prefix/suffix/char-n-gram features and lexicon-membership features.
	 *
	 * @param dicPath     lexicon directory (seeddic.txt, B-S.txt, E-S.txt)
	 * @param patternList extra indicator-word strings, or null for none
	 * @param windowLo    lower bound of the FeaturesInWindow context window (e.g. -2)
	 * @param windowHi    upper bound of the FeaturesInWindow context window (e.g. 2)
	 * @throws FileNotFoundException if a lexicon file is missing
	 */
	private static Pipe buildPipe(String dicPath, List<String> patternList, int windowLo,
			int windowHi) throws FileNotFoundException {
		List<Pipe> pipes = new ArrayList<Pipe>();

		// Conjunction offsets: pair every token with its left and right neighbour.
		int[][] conjunctions = new int[][] { { -1 }, { 1 } };

		pipes.add(new SimpleTaggerSentence2TokenSequence()); // the token text itself is a feature
		pipes.add(new OffsetConjunctions(conjunctions));
		pipes.add(new FeaturesInWindow("PREV-", windowLo, windowHi)); // context-window features

		// Caller-supplied indicator words (extra symptom cues).
		if (patternList != null) {
			for (String pattern : patternList) {
				pipes.add(new RegexMatches("P-S", Pattern.compile("[" + pattern + "].*")));
			}
		}

		// Built-in Chinese indicator words. NOTE(review): "[出现].*" is a character
		// class, so it matches tokens starting with either 出 or 现 — confirm intended.
		pipes.add(new RegexMatches("P-S", Pattern.compile("[出现].*"))); // symptom cue
		pipes.add(new RegexMatches("P-S", Pattern.compile("[伴随].*"))); // symptom cue
		pipes.add(new RegexMatches("P-S", Pattern.compile("[伴].*"))); // symptom cue
		pipes.add(new RegexMatches("P-S", Pattern.compile("[导致].*"))); // symptom cue
		pipes.add(new RegexMatches("P-D", Pattern.compile("[诊为].*"))); // disease-name cue
		pipes.add(new RegexMatches("P-D", Pattern.compile("[诊断为].*"))); // disease-name cue
		pipes.add(new RegexMatches("P-D", Pattern.compile("[断为].*"))); // disease-name cue
		pipes.add(new RegexMatches("P-I", Pattern.compile("[因].*"))); // inducing-factor cue
		pipes.add(new RegexMatches("P-I", Pattern.compile("[由于].*")));
		pipes.add(new RegexMatches("E-D", Pattern.compile("[病].*"))); // inducing-factor cue

		// Entity prefix/suffix features of widths 2..4 plus character 2/3-grams.
		pipes.add(new TokenTextCharPrefix("2PREFIX=", 2));
		pipes.add(new TokenTextCharPrefix("3PREFIX=", 3));
		pipes.add(new TokenTextCharPrefix("4PREFIX=", 4));
		pipes.add(new TokenTextCharSuffix("2SUFFIX=", 2));
		pipes.add(new TokenTextCharSuffix("3SUFFIX=", 3));
		pipes.add(new TokenTextCharSuffix("4SUFFIX=", 4));
		pipes.add(new TokenTextCharNGrams("CHARNGRAM=", new int[] { 2, 3 }, true));

		// Lexicon-membership (word-formation) features. NOTE(review): FileReader uses the
		// platform default charset; for Chinese lexicons consider pinning UTF-8 explicitly.
		pipes.add(new LexiconMembership("seeddic", new FileReader(dicPath + "/seeddic.txt"), false));
		pipes.add(new LexiconMembership("B-S", new FileReader(dicPath + "/B-S.txt"), false));
		pipes.add(new LexiconMembership("E-S", new FileReader(dicPath + "/E-S.txt"), false));

		// Finally convert each token sequence into a feature-vector sequence for the CRF.
		pipes.add(new TokenSequence2FeatureVectorSequence(true, true));

		return new SerialPipes(pipes); // compose all pipes into one
	}

	/**
	 * Lists a directory's files, failing fast instead of letting listFiles() return null.
	 */
	private static File[] listFilesOrFail(String dir) throws FileNotFoundException {
		File[] files = new File(dir).listFiles();
		if (files == null) {
			throw new FileNotFoundException("Not a readable directory: " + dir);
		}
		return files;
	}

	/**
	 * Streams one blank-line-separated file through the pipe into the given instance list,
	 * releasing the underlying file handle afterwards (the original leaked it).
	 */
	private static void addFileThruPipe(InstanceList instances, File file)
			throws FileNotFoundException {
		BufferedReader reader = new BufferedReader(
				new InputStreamReader(new FileInputStream(file)));
		try {
			instances.addThruPipe(new LineGroupIterator(reader, BLANK_LINE, true));
		} finally {
			try {
				reader.close();
			} catch (IOException ignored) {
				// Best effort: the data has already been fully consumed at this point.
			}
		}
	}

	/**
	 * Trains a linear-chain CRF on the previously loaded training instances and
	 * serializes it to {@code modelPath/ner_crf.model}.
	 *
	 * @param iterations maximum number of training iterations
	 * @param modelPath  directory to write the model file into
	 * @throws IOException if the model file cannot be written
	 */
	public void trainModel(int iterations, String modelPath) throws IOException {
		// Build the model over the alphabets collected while piping the training data.
		CRF crf = new CRF(trainInstances.getDataAlphabet(), trainInstances.getTargetAlphabet());

		// Fully connected finite-state machine over the label set.
		crf.addFullyConnectedStatesForLabels();
		// Initialise the weight structure from the training data (no feature induction).
		crf.setWeightsDimensionAsIn(trainInstances, false);

		// Objective: label log-likelihood.
		CRFOptimizableByLabelLikelihood optLabel =
				new CRFOptimizableByLabelLikelihood(crf, trainInstances);
		Optimizable.ByGradientValue[] opts = new Optimizable.ByGradientValue[] { optLabel };
		// This trainer optimises with L-BFGS by default.
		CRFTrainerByValueGradients crfTrainer = new CRFTrainerByValueGradients(crf, opts);

		crfTrainer.train(trainInstances, iterations);
		System.out.println("*********************************");
		System.out.println("trainInstances:" + trainInstances.size());

		// Persist the trained model. Fix: the original never closed the stream, so the
		// buffered ObjectOutputStream was never flushed and the file could be truncated.
		ObjectOutputStream oos = new ObjectOutputStream(
				new FileOutputStream(modelPath + "/ner_crf.model"));
		try {
			oos.writeObject(crf);
		} finally {
			oos.close(); // flushes and closes the underlying FileOutputStream too
		}
	}

	/**
	 * Loads the serialized CRF, evaluates it on the test instances and writes one tagged
	 * output file per test file into {@code outputDir}.
	 *
	 * <p>Requires that one of the {@link #loadData} overloads has been called first.
	 *
	 * @param outputDir directory for the per-file tagging results
	 * @param modelPath directory containing ner_crf.model
	 * @param flag      currently unused (a disabled database-export branch depended on it)
	 * @throws IOException            if the model or an output file cannot be accessed
	 * @throws ClassNotFoundException if the serialized model class cannot be resolved
	 */
	public void testModel(String outputDir, String modelPath, boolean flag)
			throws IOException, ClassNotFoundException {

		String[] labels = new String[] { "B-S", "E-S", "B-D", "E-D", "B-I", "E-I" };
		TransducerEvaluator evaluator = new MultiSegmentationEvaluator(
				new InstanceList[] { trainInstances, testInstances },
				new String[] { "train", "test" }, labels, labels) {
			@Override
			public boolean precondition(TransducerTrainer tt) {
				// Evaluate the model every 5 training iterations.
				return tt.getIteration() % 5 == 0;
			}
		};

		// Deserialise the trained model. Fix: the original leaked both streams.
		CRF crf;
		ObjectInputStream ois = new ObjectInputStream(
				new FileInputStream(modelPath + "/ner_crf.model"));
		try {
			crf = (CRF) ois.readObject();
		} finally {
			ois.close();
		}

		evaluator.evaluateInstanceList(new NoopTransducerTrainer(crf), testInstances, "Testing");
		System.out.println(testInstances.size() + "测试实例的大小" + testmap.size() + "testmap的大小");

		// NOTE(review): this assumes exactly one instance per test file; if a file
		// contains several blank-line-separated sentences, testmap.get(i) is null for
		// i >= file count and this loop NPEs — confirm the input guarantees one sentence
		// per file (the size printouts above suggest the author relied on that).
		for (int i = 0; i < testInstances.size(); i++) {
			File file = testmap.get(i);
			Sequence<?> input = (Sequence<?>) testInstances.get(i).getData();
			writeTaggedOutput(crf, input, file, outputDir);
		}
	}

	/**
	 * Runs decoding for one sentence and writes one "token*\t&lt;label&gt;" line per token
	 * to {@code outputDir/<source file name>}, closing the writer even on failure.
	 */
	private static void writeTaggedOutput(CRF crf, Sequence<?> input, File sourceFile,
			String outputDir) throws IOException {
		@SuppressWarnings("rawtypes")
		Sequence[] outputs = apply(crf, input, 1); // best output sequence only
		int k = outputs.length;

		String fileId = sourceFile.getName();
		int dot = fileId.lastIndexOf('.');
		if (dot >= 0) { // guard: the original crashed on extension-less file names
			fileId = fileId.substring(0, dot);
		}
		System.out.println(fileId + "文件名称");

		File outputFile = new File(outputDir + "/" + sourceFile.getName());
		BufferedWriter bw = new BufferedWriter(new FileWriter(outputFile));
		try {
			for (int j = 0; j < input.size(); j++) {
				StringBuilder sb = new StringBuilder();
				// A token's toString() dumps its features newline-separated; only the
				// single-character lines are the original token text, so keep just those.
				String[] data = input.get(j).toString().split("\n");
				for (String s : data) {
					if (s.length() == 1) {
						sb.append(s);
					}
				}
				for (int a = 0; a < k; a++) {
					sb.append("*" + "\t" + outputs[a].get(j).toString());
				}
				bw.write(sb.toString());
				bw.newLine();
			}
			bw.flush();
		} finally {
			bw.close();
		}
	}

	/**
	 * Returns the k best output label sequences for the given input.
	 *
	 * @param model trained transducer
	 * @param input feature-vector sequence for one sentence
	 * @param k     number of best sequences to return; k == 1 uses plain transduction
	 * @return array of k output sequences, best first
	 */
	@SuppressWarnings("rawtypes")
	public static Sequence[] apply(Transducer model, Sequence<?> input, int k) {
		if (k == 1) {
			// Single best path: plain transduction avoids building a full lattice.
			return new Sequence[] { model.transduce(input) };
		}
		// k-best decoding via a max lattice (cache size 100, as in the original).
		MaxLatticeDefault lattice = new MaxLatticeDefault(model, input, null, 100);
		return lattice.bestOutputSequences(k).toArray(new Sequence[0]);
	}

}
